import torch
import numpy as np
from image_synthesis.modeling.codecs.base_codec import BaseCodec


class PixelCodec(BaseCodec):
    """
    Codec that quantizes RGB pixels against a fixed color palette.

    Encoding replaces every pixel with the index of its nearest palette
    color (squared Euclidean distance in RGB space); decoding looks the
    colors back up from those indices. This is similar to the pixel
    tokenization used by IGPT models.

    Args:
        token_shape: [h, w] of the token grid that ``decode`` reshapes to.
        pixel_kmens_center_path: path of the ``.npy`` file holding the
            k-means color centers.
    """

    def __init__(
        self,
        token_shape,
        pixel_kmens_center_path='data/kmeans_centers.npy',
    ):
        super().__init__()
        self.token_shape = token_shape  # [h, w]
        self.pixel_kmens_center_path = pixel_kmens_center_path

        # The palette is borrowed from IGPT and is stored normalized to
        # [-1, 1] (512 x 3 according to the original author's note).
        palette = torch.from_numpy(np.load(pixel_kmens_center_path))
        # 1 x 1 x K x 3, so it broadcasts against B x HW x 1 x 3 pixels.
        palette = palette.view(1, 1, -1, 3)
        # Map back to the original [0, 255] intensity range.
        palette = ((palette + 1) * 127.5).round()
        self.register_buffer('centers', palette)

        # The palette is fixed. `_set_trainable` is not defined in this
        # file -- presumably inherited from BaseCodec; confirm there.
        self.trainable = False
        self._set_trainable()

    @property
    def device(self):
        """Device on which the palette buffer currently lives."""
        return self.centers.device

    def get_tokens(self, x, mask=None, enc_with_mask=True, **kwargs):
        """
        Quantize an image batch into palette indices.

        Args:
            x: B x 3 x H x W image tensor with values in [0, 255].
            mask: optional tensor, expected as B x 1 x H x W.
            enc_with_mask: when a mask is given, multiply the indices by the
                mask (positions where the mask is 0 become token 0);
                otherwise the token is an untouched copy of the indices.
            **kwargs: accepted and ignored.

        Returns:
            ``{'token': index}`` without a mask; with a mask, a dict holding
            'target' (the raw indices), 'token' and the boolean 'mask', each
            B x HW.
        """
        pixels = x.to(self.device)

        # Unpacking also enforces a 4-D input.
        batch, channels, height, width = pixels.shape
        pixels = pixels.reshape(batch, channels, height * width)
        pixels = pixels.transpose(1, 2)  # B x HW x 3
        pixels = pixels[:, :, None, :]  # B x HW x 1 x 3

        # Squared distance from every pixel to every palette color.
        diff = pixels - self.centers  # B x HW x K x 3
        sq_dist = torch.sum(diff ** 2, dim=-1)  # B x HW x K

        # Index of the nearest palette color for each pixel.
        index = torch.min(sq_dist, dim=-1).indices  # B x HW

        if mask is None:
            return {'token': index}

        flat_mask = mask.flatten(1).to(self.device)  # B x HW
        if enc_with_mask:
            token = index * flat_mask.to(index.dtype)
        else:
            token = index.clone()

        return {
            'target': index,
            'token': token,
            'mask': flat_mask.to(torch.bool),
        }

    def decode(self, token):
        """
        Turn palette indices back into an image batch.

        Args:
            token: index tensor that can be viewed as B x H x W, with
                [H, W] taken from ``self.token_shape``.

        Returns:
            B x 3 x H x W tensor of palette colors.
        """
        assert self.token_shape is not None
        grid = token.view(-1, *self.token_shape)  # B x H x W
        palette = self.centers.view(-1, 3)  # K x 3
        colors = torch.nn.functional.embedding(grid, palette)  # B x H x W x 3
        return colors.permute(0, 3, 1, 2)



if __name__ == '__main__':
    # Manual smoke check of the `min(dim=-1)` pattern used by
    # PixelCodec.get_tokens: for a 1 x 16 x 512 tensor of random values it
    # prints the index of the smallest entry along the last axis and the
    # shape of that index tensor.
    scores = torch.randn(1, 16, 512)  # named to avoid shadowing builtin `input`
    _, index = scores.min(dim=-1)
    print(index)
    print(index.shape)